package com.alibaba.datax.plugin.reader.oraclereader.impl.util;

import com.alibaba.datax.common.exception.DataXException;
import com.alibaba.datax.common.util.Configuration;
import com.alibaba.datax.plugin.reader.oraclereader.impl.Constant;
import com.alibaba.datax.plugin.reader.oraclereader.impl.Key;
import com.alibaba.datax.plugin.reader.oraclereader.util.*;
import com.alibaba.fastjson.JSON;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.Types;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

/**
 * Splits a single Oracle table, or each of its configured partitions, into per-task reader
 * configurations. Each task configuration carries its own {@code Key.QUERY_SQL}.
 */
public class SingleTableSplitUtil {
    private static final Logger LOG = LoggerFactory
            .getLogger(SingleTableSplitUtil.class);

    /** Rows per split used when {@code Key.PAGE_SIZE} is not configured. */
    private static final int DEFAULT_PAGE_SIZE = 100 * 10000;

    /** Error text shared by every place that rejects the split pk's column type. */
    private static final String UNSUPPORTED_SPLIT_PK_MESSAGE =
            "您配置的DataX切分主键(splitPk)有误. 因为您配置的切分主键(splitPk) 类型 DataX 不支持. DataX 仅支持切分主键为一个,并且类型为整数或者字符串类型. 请尝试使用其他的切分主键或者联系 DBA 进行处理.";

    /**
     * Database dialect used for connections, error wrapping and pk type checks.
     * It must be assigned before any split method runs: {@link #isLongType(int)} switches on it.
     */
    public static DataBaseType DATABASE_TYPE;

    private SingleTableSplitUtil() {
    }

    /**
     * Turns a list of range predicates into one cloned task configuration per query.
     *
     * <p>With a non-empty {@code rangeList}, one query per range is produced, plus one extra
     * query for rows whose split pk IS NULL. The sampled ranges are built from non-null pk
     * values only, so those rows would otherwise never be read. With a null or empty
     * {@code rangeList}, a single unsplit query is produced.
     *
     * @param allQuerySql receives every generated query, for logging by the caller
     * @return the generated task configurations
     */
    private static List<Configuration> assemble(Configuration configuration,
                                                String partition,
                                                String splitPkName,
                                                String column,
                                                String table,
                                                String where,
                                                String scn,
                                                List<String> allQuerySql,
                                                List<String> rangeList) {
        List<Configuration> pluginParams = new ArrayList<Configuration>();
        // The base query is identical for every range, so build it once.
        String baseSql = buildQuerySql(column, table, partition, scn, where);

        if (null == rangeList || rangeList.isEmpty()) {
            // Nothing to split on: one task reads everything. The configuration is still
            // cloned so that QUERY_SQL is never written into the shared configuration.
            addQuery(configuration, baseSql, allQuerySql, pluginParams);
            return pluginParams;
        }

        // buildQuerySql only emits a WHERE clause when the user configured one.
        String connector = StringUtils.isNotBlank(where) ? " and " : " where ";
        for (String range : rangeList) {
            addQuery(configuration, baseSql + connector + range, allQuerySql, pluginParams);
        }
        addQuery(configuration, baseSql + connector + String.format(" %s IS NULL", splitPkName),
                allQuerySql, pluginParams);
        return pluginParams;
    }

    /** Clones {@code configuration}, sets its QUERY_SQL and records the query in both lists. */
    private static void addQuery(Configuration configuration, String querySql,
                                 List<String> allQuerySql, List<Configuration> pluginParams) {
        Configuration taskConfig = configuration.clone();
        taskConfig.set(Key.QUERY_SQL, querySql);
        allQuerySql.add(querySql);
        pluginParams.add(taskConfig);
    }

    /**
     * Splits the configured table into task configurations.
     *
     * <p>The targets are the whole table, or each entry of {@code Key.PARTITION} when that list
     * is non-empty. For each target:
     * <ul>
     *   <li>if the dictionary NUM_ROWS value reaches {@code pageSize}, it requests
     *       {@code NUM_ROWS / pageSize} splits;</li>
     *   <li>otherwise it requests {@code adviceNum} splits.</li>
     * </ul>
     *
     * @param configuration reader configuration; every produced task receives a clone
     * @param adviceNum     split count requested for targets smaller than {@code pageSize}
     * @return one configuration per generated query
     * @throws IllegalArgumentException if the configured {@code pageSize} is not positive
     */
    public static List<Configuration> splitSingleTable(
            Configuration configuration, int adviceNum) {
        String splitPkName = configuration.getString(Key.SPLIT_PK);
        String column = configuration.getString(Key.COLUMN);
        String table = configuration.getString(Key.TABLE);
        String where = configuration.getString(Key.WHERE, null);
        String scn = configuration.getString(Key.SCN);
        List<String> partitions = configuration.getList(Key.PARTITION, String.class);
        int pageSize = configuration.getInt(Key.PAGE_SIZE, DEFAULT_PAGE_SIZE);
        if (pageSize <= 0) {
            // Otherwise 0 divides by zero and negatives yield a meaningless split count.
            throw new IllegalArgumentException(String.format(
                    "pageSize must be greater than 0. pageSize=[%s].", pageSize));
        }

        List<Configuration> pluginParams = new ArrayList<Configuration>();
        List<String> allQuerySql = new ArrayList<String>();
        if (partitions == null || partitions.isEmpty()) {
            pluginParams.addAll(splitTarget(configuration, null, splitPkName, column, table,
                    where, scn, adviceNum, pageSize, allQuerySql));
        } else {
            for (String partition : partitions) {
                pluginParams.addAll(splitTarget(configuration, partition, splitPkName, column,
                        table, where, scn, adviceNum, pageSize, allQuerySql));
            }
        }

        LOG.info("After split(), allQuerySql=[\n{}\n].",
                StringUtils.join(allQuerySql, "\n"));

        return pluginParams;
    }

    /** Splits one target (the whole table when {@code partition} is null) into task configurations. */
    private static List<Configuration> splitTarget(Configuration configuration, String partition,
                                                   String splitPkName, String column, String table,
                                                   String where, String scn, int adviceNum,
                                                   int pageSize, List<String> allQuerySql) {
        int numRows = numRows(table, partition, scn, where, configuration);
        int splitNum = numRows >= pageSize ? numRows / pageSize : adviceNum;
        List<String> rangeList = genSplitSqlForOracle(splitPkName, table, partition, scn, where,
                configuration, splitNum);
        return assemble(configuration, partition, splitPkName, column, table, where, scn,
                allQuerySql, rangeList);
    }

    /**
     * Wraps {@link #buildQuerySql} in a ROWNUM window.
     *
     * @param pair window bounds: the left bound is inclusive ({@code rn >= left}). The right
     *             bound is exclusive ({@code rownum < right}); when it is null, the window is
     *             open-ended.
     */
    public static String buildPageQuerySql(String column, String table, String partition, String scn, String where,
                                           Pair<Integer, Integer> pair) {
        StringBuilder sb = new StringBuilder("select ").append(column).append(" from (")
                .append("select t_subquery.*, rownum rn from (")
                .append(buildQuerySql(column, table, partition, scn, where))
                .append(") t_subquery");
        if (pair.getRight() != null) {
            sb.append(" where rownum < ").append(pair.getRight());
        }
        sb.append(") where rn >= ").append(pair.getLeft());
        return sb.toString();
    }

    /**
     * Builds {@code select <column> from <table> [partition("p")] [as of scn n] [where (w)]}.
     * The partition name is double-quoted, so it is matched case-sensitively.
     */
    public static String buildQuerySql(String column, String table, String partition, String scn, String where) {
        StringBuilder sb = new StringBuilder();
        sb.append("select ")
                .append(column)
                .append(" from ")
                .append(table);
        if (StringUtils.isNotBlank(partition)) {
            sb.append(" partition(\"").append(partition).append("\")");
        }
        if (StringUtils.isNotBlank(scn)) {
            sb.append(" as of scn ").append(scn);
        }
        if (StringUtils.isNotBlank(where)) {
            sb.append(" where (").append(where).append(")");
        }
        return sb.toString();
    }

    /**
     * Queries MIN/MAX of the split pk and records its type in {@code configuration}.
     * The connection is released even when the check fails.
     */
    @SuppressWarnings("resource")
    private static Pair<Object, Object> getPkRange(Configuration configuration) {
        String pkRangeSQL = genPKRangeSQL(configuration);

        int fetchSize = configuration.getInt(Constant.FETCH_SIZE);
        String jdbcURL = configuration.getString(Key.JDBC_URL);
        String username = configuration.getString(Key.USERNAME);
        String password = configuration.getString(Key.PASSWORD);
        String table = configuration.getString(Key.TABLE);

        Connection conn = DBUtil.getConnection(DATABASE_TYPE, jdbcURL, username, password);
        try {
            return checkSplitPk(conn, pkRangeSQL, fetchSize, table, username, configuration);
        } finally {
            // checkSplitPk closes only its ResultSet; the connection is released here on every path.
            DBUtil.closeDBResources(null, null, conn);
        }
    }

    /**
     * Precheck variant of {@link #checkSplitPk}: it never writes to a configuration, and it
     * fails when no MIN/MAX row was read.
     */
    public static void precheckSplitPk(Connection conn, String pkRangeSQL, int fetchSize,
                                                       String table, String username) {
        Pair<Object, Object> minMaxPK = checkSplitPk(conn, pkRangeSQL, fetchSize, table, username, null);
        if (null == minMaxPK) {
            throw DataXException.asDataXException(DBUtilErrorCode.ILLEGAL_SPLIT_PK,
                    "根据切分主键切分表失败. DataX 仅支持切分主键为一个,并且类型为整数或者字符串类型. 请尝试使用其他的切分主键或者联系 DBA 进行处理.");
        }
    }

    /**
     * Validates the splitPk configuration by running {@code pkRangeSQL} (a MIN/MAX query).
     *
     * <p>A null {@code configuration} means the precheck path, which must not write
     * {@code Constant.PK_TYPE} back. Integer-typed pks whose MIN/MAX text contains '.' are
     * rejected.
     *
     * @return the (min, max) pair as strings, or null when the query returned no row
     */
    private static Pair<Object, Object> checkSplitPk(Connection conn, String pkRangeSQL, int fetchSize,  String table,
                                                     String username, Configuration configuration) {
        LOG.info("split pk [sql={}] is running... ", pkRangeSQL);
        ResultSet rs = null;
        Pair<Object, Object> minMaxPK = null;
        try {
            try {
                rs = DBUtil.query(conn, pkRangeSQL, fetchSize);
            } catch (Exception e) {
                throw RdbmsException.asQueryException(DATABASE_TYPE, e, pkRangeSQL, table, username);
            }
            ResultSetMetaData rsMetaData = rs.getMetaData();
            if (!isPKTypeValid(rsMetaData)) {
                throw DataXException.asDataXException(DBUtilErrorCode.ILLEGAL_SPLIT_PK,
                        UNSUPPORTED_SPLIT_PK_MESSAGE);
            }
            int pkType = rsMetaData.getColumnType(1);
            if (isStringType(pkType)) {
                if (configuration != null) {
                    configuration.set(Constant.PK_TYPE, Constant.PK_TYPE_STRING);
                }
                while (DBUtil.asyncResultSetNext(rs)) {
                    minMaxPK = new ImmutablePair<Object, Object>(
                            rs.getString(1), rs.getString(2));
                }
            } else if (isLongType(pkType)) {
                if (configuration != null) {
                    configuration.set(Constant.PK_TYPE, Constant.PK_TYPE_LONG);
                }
                while (DBUtil.asyncResultSetNext(rs)) {
                    minMaxPK = new ImmutablePair<Object, Object>(
                            rs.getString(1), rs.getString(2));

                    // Oracle reports NUMBER as NUMERIC even with a fractional part; reject it.
                    String minMax = rs.getString(1) + rs.getString(2);
                    if (StringUtils.contains(minMax, '.')) {
                        throw DataXException.asDataXException(DBUtilErrorCode.ILLEGAL_SPLIT_PK,
                                UNSUPPORTED_SPLIT_PK_MESSAGE);
                    }
                }
            } else {
                throw DataXException.asDataXException(DBUtilErrorCode.ILLEGAL_SPLIT_PK,
                        UNSUPPORTED_SPLIT_PK_MESSAGE);
            }
        } catch (DataXException e) {
            throw e;
        } catch (Exception e) {
            throw DataXException.asDataXException(DBUtilErrorCode.ILLEGAL_SPLIT_PK, "DataX尝试切分表发生错误. 请检查您的配置并作出修改.", e);
        } finally {
            DBUtil.closeDBResources(rs, null, null);
        }

        return minMaxPK;
    }

    /**
     * Checks that both MIN and MAX columns have the same JDBC type, and that it is an integer
     * or string type.
     */
    private static boolean isPKTypeValid(ResultSetMetaData rsMetaData) {
        try {
            int minType = rsMetaData.getColumnType(1);
            int maxType = rsMetaData.getColumnType(2);
            return minType == maxType && (isLongType(minType) || isStringType(minType));
        } catch (Exception e) {
            // Keep the underlying cause so the real metadata/driver failure is visible.
            throw DataXException.asDataXException(DBUtilErrorCode.ILLEGAL_SPLIT_PK,
                    "DataX获取切分主键(splitPk)字段类型失败. 该错误通常是系统底层异常导致. 请联系旺旺:askdatax或者DBA处理.", e);
        }
    }

    // Types.NUMERIC is accepted only for Oracle: Oracle stores INT, SMALLINT, INTEGER, etc.
    // as NUMBER, which JDBC reports as NUMERIC.
    private static boolean isLongType(int type) {
        boolean isValidLongType = type == Types.BIGINT || type == Types.INTEGER
                || type == Types.SMALLINT || type == Types.TINYINT;

        switch (SingleTableSplitUtil.DATABASE_TYPE) {
            case Oracle:
                isValidLongType |= type == Types.NUMERIC;
                break;
            default:
                break;
        }
        return isValidLongType;
    }

    /** Returns whether {@code type} is one of the JDBC character types usable as a split pk. */
    private static boolean isStringType(int type) {
        return type == Types.CHAR || type == Types.NCHAR
                || type == Types.VARCHAR || type == Types.LONGVARCHAR
                || type == Types.NVARCHAR;
    }

    /** Builds the MIN/MAX pk query from the splitPk, table and where settings. */
    private static String genPKRangeSQL(Configuration configuration) {
        String splitPK = configuration.getString(Key.SPLIT_PK).trim();
        String table = configuration.getString(Key.TABLE).trim();
        String where = configuration.getString(Key.WHERE, null);
        return genPKSql(splitPK, table, where);
    }

    /**
     * Builds {@code SELECT MIN(pk),MAX(pk) FROM table}. When {@code where} is set, the query is
     * restricted to that condition and to non-null pks.
     */
    public static String genPKSql(String splitPK, String table, String where){
        String pkRangeSQL = String.format("SELECT MIN(%s),MAX(%s) FROM %s", splitPK, splitPK, table);
        if (StringUtils.isNotBlank(where)) {
            pkRangeSQL = String.format("%s WHERE (%s AND %s IS NOT NULL)",
                    pkRangeSQL, where, splitPK);
        }
        return pkRangeSQL;
    }

    /**
     * Builds a dictionary query for the NUM_ROWS value of a table (ALL_TABLES), or of a
     * partition (ALL_TAB_PARTITIONS).
     *
     * <p>Identifier parts of {@code table} are normalized the way Oracle resolves them, so the
     * dictionary lookup matches the table that {@link #buildQuerySql} actually reads:
     * <ul>
     *   <li>a double-quoted part keeps its case;</li>
     *   <li>an unquoted part is upper-cased.</li>
     * </ul>
     * The partition name is used as-is, because {@link #buildQuerySql} quotes it.
     *
     * <p>{@code scn} and {@code where} are not applied: the dictionary value covers the whole
     * table or partition.
     */
    public static String buildNumRowsSql(String table, String partition, String scn, String where) {
        String schemaName = null;
        String tableName;
        if (table.contains(".")) {
            String[] names = table.split("\\.");
            schemaName = toDictionaryName(names[0]);
            tableName = toDictionaryName(names[1]);
        } else {
            tableName = toDictionaryName(table);
        }
        boolean hasPartition = StringUtils.isNotBlank(partition);

        StringBuilder sb = new StringBuilder("SELECT NUM_ROWS FROM ");
        sb.append(hasPartition ? "ALL_TAB_PARTITIONS " : "ALL_TABLES ");
        sb.append("WHERE ");
        if (StringUtils.isNotBlank(schemaName)) {
            sb.append(hasPartition ? "TABLE_OWNER = " : "OWNER = ")
                    .append(toSqlLiteral(schemaName))
                    .append(" AND ");
        }
        sb.append("TABLE_NAME = ").append(toSqlLiteral(tableName));
        if (hasPartition) {
            sb.append(" AND PARTITION_NAME = ").append(toSqlLiteral(partition));
        }
        return sb.toString();
    }

    /** Strips the quotes from a quoted identifier; otherwise upper-cases it, as Oracle does. */
    private static String toDictionaryName(String identifier) {
        String trimmed = identifier.trim();
        if (trimmed.length() >= 2 && trimmed.startsWith("\"") && trimmed.endsWith("\"")) {
            return trimmed.substring(1, trimmed.length() - 1);
        }
        return trimmed.replace("\"", "").toUpperCase(Locale.ROOT);
    }

    /** Wraps {@code value} in single quotes, doubling any embedded quote. */
    private static String toSqlLiteral(String value) {
        return "'" + value.replace("'", "''") + "'";
    }

    /**
     * Reads the dictionary NUM_ROWS value for the table or partition.
     *
     * @return the row count; 0 when no dictionary row matches or NUM_ROWS is NULL.
     *         Values above {@code Integer.MAX_VALUE} are clamped to it.
     * @throws RuntimeException wrapping any connection or query failure
     */
    public static Integer numRows(String table, String partition, String scn,
                                        String where, Configuration configuration) {
        String countSql = buildNumRowsSql(table, partition, scn, where);
        String jdbcURL = configuration.getString(Key.JDBC_URL);
        String username = configuration.getString(Key.USERNAME);
        String password = configuration.getString(Key.PASSWORD);
        Connection conn = null;
        ResultSet rs = null;
        try {
            conn = DBUtil.getConnection(DATABASE_TYPE, jdbcURL, username, password);
            rs = DBUtil.query(conn, countSql);
            if (!rs.next()) {
                return 0;
            }
            // NUM_ROWS can exceed the int range, where getInt would fail, so read a long and
            // clamp it. getLong returns 0 for SQL NULL.
            long count = rs.getLong(1);
            return (int) Math.min(count, Integer.MAX_VALUE);
        } catch (Exception e) {
            throw new RuntimeException("Failed to read NUM_ROWS, sql=[" + countSql + "].", e);
        } finally {
            DBUtil.closeDBResources(rs, null, conn);
        }
    }

    /**
     * Computes ROWNUM windows for {@link #buildPageQuerySql}, based on the dictionary row count.
     *
     * @param pageSize target rows per window; must be positive
     * @return windows as (inclusive start, exclusive end) pairs. The last window's end is null
     *         (open-ended). The list is empty when fewer than two windows would result.
     * @throws IllegalArgumentException if {@code pageSize} is null or not positive
     */
    public static List<Pair<Integer, Integer>> genSplitRowNum(String table, String partition, String scn,
                                                    String where, Configuration configuration,
                                                    Integer pageSize) {
        if (pageSize == null || pageSize <= 0) {
            throw new IllegalArgumentException(String.format(
                    "pageSize must be greater than 0. pageSize=[%s].", pageSize));
        }
        List<Pair<Integer, Integer>> rns = new ArrayList<>();
        int count = numRows(table, partition, scn, where, configuration);
        int pageCount = count / pageSize;
        if (pageCount <= 1) {
            return rns;
        }
        int rowsPerPage = count / pageCount;
        for (int page = 1; page <= pageCount; page++) {
            int start = rowsPerPage * (page - 1) + 1;
            // The last window is open-ended so the remainder of count / pageCount is not lost.
            Integer end = page < pageCount ? Integer.valueOf(rowsPerPage * page + 1) : null;
            rns.add(Pair.of(start, end));
        }
        return rns;
    }

    /**
     * Builds range predicates on {@code splitPK} by randomly sampling pk values.
     * Integer and string pks are supported.
     *
     * <p>Sets {@code Constant.PK_TYPE} to {@code PK_TYPE_MONTECARLO} on {@code configuration}
     * once the sample query succeeds.
     *
     * @param adviceNum maximum number of sample points to draw; must be at least 1
     * @return null when {@code adviceNum} is 1. An empty list when fewer than two sample points
     *         were found. Otherwise, the predicates produced by RdbmsRangeSplitWrap.
     * @throws IllegalArgumentException if {@code adviceNum} is less than 1
     */
    public static List<String> genSplitSqlForOracle(String splitPK,
            String table, String partition, String scn, String where, Configuration configuration,
            int adviceNum) {
        if (adviceNum < 1) {
            throw new IllegalArgumentException(String.format(
                    "切分份数不能小于1. 此处:adviceNum=[%s].", adviceNum));
        } else if (adviceNum == 1) {
            return null;
        }

        Double percentage = configuration.getDouble(Key.SAMPLE_PERCENTAGE, 0.1);
        String splitSql = buildSampleSql(splitPK, table, partition, scn, where, percentage, adviceNum);
        List<Pair<Object, Integer>> samples = samplePoints(splitSql, table, configuration);
        if (LOG.isDebugEnabled()) {
            LOG.debug(JSON.toJSONString(samples));
        }
        return wrapSamples(samples, splitPK);
    }

    /**
     * Builds the Oracle SAMPLE query. It returns up to {@code adviceNum} random non-null pk
     * values, in ascending order.
     *
     * <p>Partition and scn are appended literally rather than through {@code String.format},
     * so a '%' in them cannot break formatting.
     */
    private static String buildSampleSql(String splitPK, String table, String partition, String scn,
                                         String where, Double percentage, int adviceNum) {
        String notNull = splitPK + " IS NOT NULL";
        String whereSql = StringUtils.isNotBlank(where)
                ? " WHERE (" + notNull + ") AND (" + where + ") "
                : " WHERE (" + notNull + ") ";

        StringBuilder sb = new StringBuilder("SELECT * FROM ( SELECT ")
                .append(splitPK).append(" FROM ").append(table).append(' ');
        if (StringUtils.isNotBlank(partition)) {
            sb.append("partition(\"").append(partition).append("\") ");
        }
        sb.append("SAMPLE (").append(percentage).append(") ");
        if (StringUtils.isNotBlank(scn)) {
            sb.append("as of scn ").append(scn).append(' ');
        }
        sb.append(whereSql)
                .append(" ORDER BY DBMS_RANDOM.VALUE) WHERE ROWNUM <= ").append(adviceNum)
                .append(" ORDER by ").append(splitPK).append(" ASC");
        return sb.toString();
    }

    /** Runs the sample query and returns each sampled value paired with the column's JDBC type. */
    private static List<Pair<Object, Integer>> samplePoints(String splitSql, String table,
                                                           Configuration configuration) {
        int fetchSize = configuration.getInt(Constant.FETCH_SIZE, 32);
        String jdbcURL = configuration.getString(Key.JDBC_URL);
        String username = configuration.getString(Key.USERNAME);
        String password = configuration.getString(Key.PASSWORD);
        Connection conn = DBUtil.getConnection(DATABASE_TYPE, jdbcURL,
                username, password);
        LOG.info("split pk [sql={}] is running... ", splitSql);
        ResultSet rs = null;
        List<Pair<Object, Integer>> points = new ArrayList<Pair<Object, Integer>>();
        try {
            try {
                rs = DBUtil.query(conn, splitSql, fetchSize);
            } catch (Exception e) {
                throw RdbmsException.asQueryException(DATABASE_TYPE, e,
                        splitSql, table, username);
            }
            configuration.set(Constant.PK_TYPE, Constant.PK_TYPE_MONTECARLO);
            // Every row comes from the same single column, so read its type once.
            ResultSetMetaData rsMetaData = rs.getMetaData();
            Integer pkType = rsMetaData.getColumnType(1);
            while (DBUtil.asyncResultSetNext(rs)) {
                points.add(new ImmutablePair<Object, Integer>(rs.getObject(1), pkType));
            }
        } catch (DataXException e) {
            throw e;
        } catch (Exception e) {
            throw DataXException.asDataXException(
                    DBUtilErrorCode.ILLEGAL_SPLIT_PK,
                    "DataX尝试切分表发生错误. 请检查您的配置并作出修改.", e);
        } finally {
            DBUtil.closeDBResources(rs, null, conn);
        }
        return points;
    }

    /**
     * Converts sorted sample points into range predicates. Fewer than two points yield an empty
     * list, which {@link #assemble} treats as "do not split".
     */
    private static List<String> wrapSamples(List<Pair<Object, Integer>> samples, String splitPK) {
        List<String> rangeSql = new ArrayList<String>();
        int size = samples.size();
        if (size < 2) {
            return rangeSql;
        }
        int pkType = samples.get(0).getRight();
        if (isLongType(pkType)) {
            // Oracle NUMBER arrives here as a "long" type (see isLongType).
            BigInteger[] points = new BigInteger[size];
            for (int i = 0; i < size; i++) {
                points[i] = toBigInteger(samples.get(i).getLeft());
            }
            rangeSql.addAll(RdbmsRangeSplitWrap.wrapRange(points, splitPK));
            rangeSql.add(RdbmsRangeSplitWrap.wrapFirstLastPoint(
                    points[0], points[size - 1], splitPK));
        } else if (isStringType(pkType)) {
            String[] points = new String[size];
            for (int i = 0; i < size; i++) {
                points[i] = samples.get(i).getLeft().toString();
            }
            rangeSql.addAll(RdbmsRangeSplitWrap.wrapRange(points,
                    splitPK, "'", DATABASE_TYPE));
            rangeSql.add(RdbmsRangeSplitWrap.wrapFirstLastPoint(
                    points[0], points[size - 1], splitPK, "'", DATABASE_TYPE));
        } else {
            throw DataXException.asDataXException(DBUtilErrorCode.ILLEGAL_SPLIT_PK,
                    UNSUPPORTED_SPLIT_PK_MESSAGE);
        }
        return rangeSql;
    }

    /**
     * Converts a sampled numeric pk value to an exact integer.
     *
     * <p>Values with a non-zero fractional part (possible for Oracle NUMBER) are rejected with
     * the same error as an unsupported pk type. Before this, such values escaped as a raw
     * NumberFormatException.
     */
    private static BigInteger toBigInteger(Object value) {
        try {
            BigDecimal decimal = value instanceof BigDecimal
                    ? (BigDecimal) value
                    : new BigDecimal(value.toString());
            return decimal.toBigIntegerExact();
        } catch (ArithmeticException | NumberFormatException e) {
            throw DataXException.asDataXException(DBUtilErrorCode.ILLEGAL_SPLIT_PK,
                    UNSUPPORTED_SPLIT_PK_MESSAGE, e);
        }
    }
}